{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## An Introduction to K-Means Clustering\n",
    "\n",
    "by Scott Hendrickson & Fiona Pigott\n",
    "\n",
    "### K-Means is for learning unknown categories\n",
    "\n",
    "K-means is a machine learning technique for learning unknown categories--in other words, a technique for *unsupervised* learning. K-means tries to group n-dimensional data into clusters, where the actual position of those clusters is unknown. \n",
    "\n",
    "### Basic Idea\n",
    "\n",
    "From the Wikipedia article on k-means clustering:\n",
    "\n",
    "> \"k-means clustering aims to partition n observations into k clusters in which each observation belongs to the cluster with the nearest mean, serving as a prototype of the cluster\"\n",
    "\n",
    "Basically, k-means assumes that for some sensible distance metric, it's possible to partition data into groups around the \"center\" (\"centroid\") of different naturally separated clusters in the data.  \n",
    "\n",
    "This concept can be very useful for separating datasets that came from separate generative processes, where the location of each dataset is pretty much unknown. It only works well if there is an expectation that the datasets are clustered around their means, and that the means would reasonably be different. A classic example of where k-means would *not* separate datasets well is when the datasets have different distibutions, but a similar mean. *Not* a good problem to apply k-means to: <img src=\"files/ring_clusters.png\" style=\"width: 300px;\">\n",
    "\n",
    "Good problem to apply k-means to: <img src=\"files/kmeans_clusters.jpg\" style=\"width: 300px;\">\n",
    "\n",
    "#### Question: \n",
    "**Is there an underlying stucture to my data? Does my data have defined categories that I don't know about? How can I identify whuich datapoint belongs to which category?**\n",
    "\n",
    "#### Solution: \n",
    "**For a selection of centers (centroids) of data clusters (and we'll talk about how to choose centroids), for each data point, label that data point with the centroid it is closest to.**\n",
    "\n",
    "### Algorithm\n",
    "\n",
    "0) Have a dataset that you want to sort into clusters  \n",
    "1) Choose a number of clusters that you're going to look for (there are ways to optimize this, but you have to fix it for the next step)  \n",
    "2) Guess at cluster membership for each data point (basically, for each data point, randomly assign it to a cluster)  \n",
    "3) Find the center (\"centroid\") of each cluster (with the data points that you've assigned to it)  \n",
    "4) For each centroid, find which data points are closest to it, and assign those data points to its cluster  \n",
    "5) Repeat 3 & 4 (re-evaluate centroids based on new cluster memberhip, then re-assign clusters based on new centroids)  "
   ]
  },
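  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what steps 3 and 4 compute, written in notation that is our own assumption for illustration (data points $x_1, \\dots, x_n$, clusters $S_1, \\dots, S_k$, centroids $\\mu_1, \\dots, \\mu_k$, Euclidean distance):\n",
    "\n",
    "$$\\mu_j = \\frac{1}{|S_j|} \\sum_{x_i \\in S_j} x_i \\qquad \\text{(step 3: centroid of each cluster)}$$\n",
    "\n",
    "$$S_j = \\lbrace x_i : \\lVert x_i - \\mu_j \\rVert \\le \\lVert x_i - \\mu_l \\rVert \\ \\text{for all } l \\rbrace \\qquad \\text{(step 4: assign each point to its nearest centroid)}$$\n",
    "\n",
    "Ties can be broken arbitrarily, as long as each point ends up in exactly one cluster."
   ]
  },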
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0) A cloud of data in two dimensions\n",
    "\n",
    "#### Setting up an example of data that could be separated by k-means:\n",
    "First, we'll generate a synthetic dataset from two different spherical gaussian distributions, setting the spacing so that clouds of data overlap a litte."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Import some python libraries that we'll need\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import math\n",
    "import sys\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def make_data(n_points, n_clusters=2, dim=2, sigma=1):\n",
    "    x = [[] for i in range(dim)]\n",
    "    for i in range(n_clusters):\n",
    "        for d in range(dim):\n",
    "            x[d].extend([random.gauss(i*3,sigma) for j in range(n_points)])\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# make our synthetic data\n",
    "num_clusters = 2\n",
    "num_points = 100\n",
    "data_sample = make_data(num_points, num_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot our synthetic data\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*data_sample)\n",
    "ax.set_title(\"Sample dataset, {} points per cluster, {} clusters\".format(num_points,num_clusters))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) How might we identify the *two* clusters?\n",
    "\n",
    "#### We're going to set $k = 2$ before trying to use k-means to separate the clusters\n",
    "We happen to *know* that $k = 2$ because we k=just made up this data with two distributions. I'll talk a little at the end about how to guess $k$ for a real-world dataset.\n",
    "\n",
    "#### Because we created this example, we know the \"truth\"\n",
    "We know which data came from which distribution (this is what k-means is trying to discover).\n",
    "\n",
    "Here's the truth, just to compare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot our synthetic data\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(data_sample[0][0:100], data_sample[1][0:100])\n",
    "ax.scatter(data_sample[0][100:200], data_sample[1][100:200])\n",
    "ax.set_title(\"Sample dataset, {} points per cluster, {} clusters\".format(num_points,num_clusters))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Start by guessing the cluster membership\n",
    "\n",
    "In this case, guessing means \"randomly assign cluster membership.\" There are other heuristics that you could use to make an initial guess, but we won't get into those here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# each cluster membership is going to have a color label (\"red\" cluster, \"orange\" cluster, etc)\n",
    "co = [\"red\", \"orange\", \"yellow\", \"green\", \"purple\", \"blue\", \"black\",\"brown\"]\n",
    "def guess_clusters(x, n_clusters):\n",
    "    # req co list of identifiers\n",
    "    for i in range(len(x[0])):\n",
    "        return [ co[random.choice(range(n_clusters))] for i in range(len(x[0]))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now guess the cluster membership--simply by randomly assigning a cluster label \n",
    "# \"orange\" or \"red\" to each of the data points \n",
    "membership_2 = guess_clusters(data_sample,2)\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*data_sample, color=membership_2)\n",
    "ax.set_title(\"Data set drawn from 2 different 2D Gaussian distributions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Find the center of a set of data points\n",
    "\n",
    "We'll need a way of determining the center of a set of data, after we make a guess at cluster membership.\n",
    "\n",
    "In this case, we'll find the centers of the two clusters that we guessed about."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def centroid(x):\n",
    "    return [[sum(col)/float(len(x[0]))] for col in x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# function to select members of only one cluster\n",
    "def select_members(x, membership, cluster):\n",
    "    return [ [i for i,label in zip(dim, membership) if label == cluster] for dim in x ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*select_members(data_sample, membership_2, \"red\"), color=\"red\")\n",
    "ax.scatter(*centroid(select_members(data_sample, membership_2, \"red\")), color=\"black\", marker=\"*\", s = 100)\n",
    "ax.set_title(\"Centroid of the 'red' cluster (black star)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*select_members(data_sample, membership_2, \"orange\"), color=\"orange\")\n",
    "ax.scatter(*centroid(select_members(data_sample, membership_2, \"orange\")), color=\"black\", marker=\"*\", s = 100)\n",
    "ax.set_title(\"Centroid of the 'orange' cluster (black star)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Update membership of points to closest centroid\n",
    "\n",
    "#### Find distances (will use to find distances to the centroid, in this case):"
   ]
  },
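  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick worked version of the distance used here, assuming plain Euclidean distance (the symbols $p$, $q$ and $D$ are our own labels for two points and the number of dimensions):\n",
    "\n",
    "$$\\mathrm{dist}(p, q) = \\sqrt{\\sum_{m=1}^{D} (p_m - q_m)^2}$$\n",
    "\n",
    "For $p = (-1, -1)$ and $q = (2, 3)$ this gives\n",
    "\n",
    "$$\\sqrt{(-1 - 2)^2 + (-1 - 3)^2} = \\sqrt{9 + 16} = 5$$"
   ]
  },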
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distance(p1, p2):\n",
    "    # odd... vectors are lists of lists with only 1 element in each dim\n",
    "    return math.sqrt(sum([(i[0]-j[0])**2 for i,j in zip(p1, p2)]))\n",
    "\n",
    "# here's the distance between two points, just to show how it works\n",
    "print(\"Distance between (-1,-1) and (2,3): {}\".format(distance([[-1],[-1]],[[2],[3]])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def reassign(x, centriods):\n",
    "    membership = []\n",
    "    for idx in range(len(x[0])):\n",
    "        min_d = sys.maxsize\n",
    "        cluster = \"\"\n",
    "        for c, vc in centriods.items():\n",
    "            dist = distance(vc, [[t[idx]] for t in x])\n",
    "            if dist < min_d:\n",
    "                min_d = dist\n",
    "                cluster = c\n",
    "        membership.append(cluster)\n",
    "    return membership"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cent_2 = {i:centroid(select_members(data_sample, membership_2, i)) for i in co[:2]}\n",
    "membership_2 = reassign(data_sample, cent_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*data_sample, color=membership_2)\n",
    "ax.scatter(*cent_2[\"red\"], color=\"black\", marker=\"*\", s = 360)\n",
    "ax.scatter(*cent_2[\"orange\"], color=\"black\", marker=\"*\", s = 360)\n",
    "ax.scatter(*cent_2[\"red\"], color=\"red\", marker=\"*\", s = 200)\n",
    "ax.scatter(*cent_2[\"orange\"], color=\"orange\", marker=\"*\", s = 200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5) Put it all together so that we can iterate\n",
    "Now we're going to iterate--assign clusters, finda  centroid, reassign clusters--until the centroid positions stop changing very much."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# function\n",
    "def get_centroids(x, membership):\n",
    "    return {i:centroid(select_members(x, membership, i)) for i in set(membership)}\n",
    "\n",
    "# redefine with total distance measure\n",
    "def reassign(x, centroids):\n",
    "    membership, scores = [], {}\n",
    "    # step through all the vectors\n",
    "    for idx in range(len(x[0])):\n",
    "        min_d, cluster = sys.maxsize, None # set the min distance to a large number (we're about to minimize it)\n",
    "        for c, vc in centroids.items():\n",
    "            # get the sum of the distances from each point in the cluster to the centroids\n",
    "            dist = distance(vc, [[t[idx]] for t in x])\n",
    "            if dist < min_d:\n",
    "                min_d = dist\n",
    "                cluster = c\n",
    "        # score is the minumum distance from each point in a cluster to the centroid of that cluster\n",
    "        scores[cluster] = min_d + scores.get(cluster, 0)\n",
    "        membership.append(cluster)\n",
    "        # retrun the membership & the sum of all the score over all of the clusters\n",
    "    return membership, sum(scores.values())/float(len(x[0]))\n",
    "\n",
    "def k_means(data, k):\n",
    "    # start with random distribution\n",
    "    membership = guess_clusters(data, k)\n",
    "    score, last_score = 0.0, sys.maxsize\n",
    "    while abs(last_score - score) > 1e-7:\n",
    "        last_score = score\n",
    "        c = get_centroids(data, membership)\n",
    "        membership, score = reassign(data, c)\n",
    "        #print(last_score - score)\n",
    "    return membership, c, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mem, cl, s = k_means(data_sample, 2)\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.scatter(*data_sample, color = mem)\n",
    "for i, pt in cl.items():\n",
    "    ax.scatter(*pt, color=\"black\", marker=\"*\", s = 16*8)\n",
    "ax.set_title(\"Clustering from k-means\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### K-means with real data\n",
    "\n",
    "#### Figuring out how many clusters to look for (that pesky \"Step 1\")\n",
    "Now, one thing we haven't covered yet is how to decide on the number of clusters to look for in the first place. There are several different heuristics that we can use to figure out what the \"best\" number of clusters is (we go into this more in https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/choosing-k-in-kmeans).\n",
    "\n",
    "The one heuristic that we're going to talk about here is finding the \"knee\" in the k-means error function.\n",
    "\n",
    "#### The error function:\n",
    "In this case, the error function is simply the sum of all of the distances from each data point to its assigned cluster, summed over all of the clusters. The further each data point is from its assigned cluster, the larger this error score is.\n",
    "\n",
    "#### Look for the \"knee\":\n",
    "When I say \"knee\" I mean to look for a bend in the graoh of the error score vs $k$. The idea is to find the place where you get a smaller decrease in the error (distance from each data point to a centroid) for every increase in the number of clusters ($k$). "
   ]
  },
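  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the error score concrete, here is a sketch in symbols that are our own assumption for illustration: $n$ data points $x_i$, with $\\mu_{c(i)}$ the centroid of the cluster that $x_i$ is assigned to. The score that `reassign` returns above is the total distance divided by the number of points:\n",
    "\n",
    "$$E(k) = \\frac{1}{n} \\sum_{i=1}^{n} \\lVert x_i - \\mu_{c(i)} \\rVert$$\n",
    "\n",
    "Note that this uses plain distances. The quantity more commonly reported for k-means (for example, scikit-learn's `inertia_`) is the sum of *squared* distances; the \"knee\" heuristic can be applied to either."
   ]
  },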
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "err = []\n",
    "trial_ks = range(1,5)\n",
    "results = {}\n",
    "for k in trial_ks:\n",
    "    mem_2, cl_2, s_2 = k_means(data_sample, k)\n",
    "    results[k] = mem_2\n",
    "    err.append(s_2)\n",
    "    \n",
    "f, axes = plt.subplots(1, len(trial_ks), sharey=True, figsize = (18,4))\n",
    "for i,k in enumerate(trial_ks):\n",
    "    axes[i].set_aspect('equal')\n",
    "    axes[i].set_title(\"k-means results with k = {} \\n error = {:f}\".format(k, err[i]))\n",
    "    axes[i].scatter(*data_sample, color = results[k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the error as a function of the number of clusters\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.plot(trial_ks,err,'o--')\n",
    "ax.set_title(\"Error as a funtion of k\")\n",
    "ax.xaxis.set_ticks(trial_ks)\n",
    "_ = ax.set_xlabel(\"number of clusters (k)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a different example, this time with 4 clusters\n",
    "ex4 = make_data(200, 4) \n",
    "err4 = []\n",
    "trial_ks_4 = range(1,9)\n",
    "results_4 = {}\n",
    "for k in trial_ks_4:\n",
    "    mem_ex4, cl_ex4, s_ex4 = k_means(ex4, k)\n",
    "    results_4[k] = mem_ex4\n",
    "    err4.append(s_ex4)\n",
    "    \n",
    "f, axes = plt.subplots(2, int(len(trial_ks_4)/2), sharey=True, figsize = (18,11))\n",
    "for i,k in enumerate(trial_ks_4):\n",
    "    axes[int(i >= 4)][i%4].set_aspect('equal')\n",
    "    axes[int(i >= 4)][i%4].set_title(\"k-means results with k = {} \\n error = {:f}\".format(k, err4[i]))\n",
    "    axes[int(i >= 4)][i%4].scatter(*ex4, color = results_4[k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the error as a function of the number of clusters\n",
    "fig = plt.figure(figsize = [6,6])\n",
    "ax = fig.add_subplot(111)\n",
    "ax.set_aspect('equal')\n",
    "ax.plot(trial_ks_4,err4,'o--')\n",
    "ax.set_title(\"Error as a funtion of k\")\n",
    "ax.xaxis.set_ticks(trial_ks_4)\n",
    "_ = ax.set_xlabel(\"number of clusters (k)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Choose the number that seems to make sense\n",
    "\n",
    "The \"number that makes sense\" is where the \"knee\" in the error function occurs. Basically, where we start getting little or no decrease in the error, even as we increase $k$.\n",
    "\n",
    "Here, you might see that after $k=4$, the error stops decreasing very quickly, and you might choose $k=4$ as the best number of clusters in the data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### K-means in the wild: don't use this code. Use Scikit Learn!\n",
    "\n",
    "The tutorial on using Scikit-learn (https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/sklearn-101) and also the one on finding the number of clusters (https://github.com/DrSkippy/Data-Science-45min-Intros/tree/master/choosing-k-in-kmeans) both use the Scikit-learn K-means algorithm, which has a well-thought-out API and is way faster than a homemade version. Try it out!\n",
    "\n",
    "### Thanks for playing!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
